import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore
# from matplotlib.ticker import PercentFormatter
from matplotlib.ticker import MultipleLocator

def plot_colormaps_diff(cases, plot_labels, color_dict, save_name, show_legend=False):
    """Box-plot the slice/box displacement-amplitude ratio (in dB) per frequency.

    The reference data is read from 'comsol_raw/box_colormap.csv'; for every
    entry of ``cases`` the file 'comsol_raw/slice_colormap<case>.csv' is read
    and its 'solid.uAmp{Z,X,Y}' columns are compared with the matching
    reference columns. For each (case, direction, frequency) one box
    summarising ``20*log10(slice/box)`` is drawn; directions Z, X, Y go to
    subplot columns 0, 1, 2. The figure is saved, shown, and closed.

    Args:
        cases: sequence of file-name suffixes; each must be a key of
            ``color_dict``. A repeated suffix is simply read and drawn again.
        plot_labels: legend labels, indexed in parallel with ``cases``.
        color_dict: mapping from case suffix to a matplotlib color.
        save_name: output path handed to ``plt.savefig``.
        show_legend: when true, draw a legend on the first subplot.
    """
    freqs = [4000, 4500, 5000, 5500, 6000, 6500, 7000]
    directions = ['Z', 'X', 'Y']

    # skiprows=8 presumably skips the COMSOL export header -- confirm against the files.
    box_df = pd.read_csv('comsol_raw/box_colormap.csv', skiprows=8, index_col=[0, 1])

    fig, axarrarr = plt.subplots(2, 3, sharex='col', sharey='row', figsize=[8, 3.2])
    # Fix: the original indexed `axarr` before it was ever assigned, which
    # raised UnboundLocalError on every call. Take the top row of the grid.
    # NOTE(review): nothing in this function draws on the bottom row, and the
    # sibling phase plot uses a 1x3 grid -- confirm whether 2 rows is intended.
    axarr = axarrarr[0, :]

    for casei, case in enumerate(cases):
        print('case = '+case)

        slice_df = pd.read_csv('comsol_raw/slice_colormap' + case + '.csv',
                               skiprows=8, index_col=[0, 1])
        color = color_dict[case]

        for j, direction in enumerate(directions):
            print(j, direction)
            # The Y direction is deliberately not plotted for the '_cont' case.
            if direction == 'Y' and case == '_cont':
                continue

            for i, freqi in enumerate(freqs):
                print(i, freqi)

                column = 'solid.uAmp' + direction + ' (mm) @ freq=' + str(freqi)
                box_z = np.asarray(box_df.loc[:, column])
                slice_z = np.asarray(slice_df.loc[:, column])

                # Mask reference amplitudes below 1e-5 so they drop out of the
                # ratio; np.where builds a new array instead of writing into
                # memory that may be shared with box_df.
                box_z = np.where(box_z < 1e-5, np.nan, box_z)

                z = 20*np.log10(slice_z/box_z)
                z = z[~np.isnan(z)]

                axarr[j].boxplot(
                    z, positions=[freqi/1000], widths=0.25, patch_artist=True,
                    showmeans=False, showfliers=False,
                    medianprops={"color": "white", "linewidth": 0.5},
                    boxprops={"facecolor": color, "edgecolor": "none",
                              "linewidth": 0.5},
                    whiskerprops={"color": color, "linewidth": 1.5},
                    capprops={"color": color, "linewidth": 1.5})

        # Empty proxy artist: contributes only a legend entry for this case.
        axarr[0].plot([], [], linestyle='', marker='s', color=color,
                      label=plot_labels[casei])

    for ax in axarr:
        # Horizontal 0 dB reference line across the visible frequency range.
        ax.plot([3.5, 7.5], [0, 0], linestyle='-', color='dimgray', linewidth=1)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_xlabel('frequency (kHz)')
        ax.set_xlim(3.5, 7.5)
        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.xaxis.set_major_formatter('{x:.0f}')
        ax.xaxis.set_minor_locator(MultipleLocator(0.5))

    if show_legend:
        axarr[0].legend()
    axarr[0].set_ylabel('dB re. full-length result')
    fig.tight_layout()
    plt.savefig(save_name, transparent=True)
    plt.show()
    plt.close()

def _first_non_nan(values):
    """Return ``(index, value)`` of the first non-NaN entry of a 1-D array.

    Raises IndexError when every entry is NaN.
    """
    idx = np.flatnonzero(~np.isnan(values))[0]
    return idx, values[idx]


def _wrap_phase(phase, j):
    """Shift phases by +/- 2*pi using the per-direction thresholds.

    ``j`` is the direction index (0: Z, 1: X, 2: Y). The thresholds are the
    hand-picked values this script has always used (presumably tuned to this
    data set). The lower-bound shift is applied first and the upper-bound
    check then sees the shifted value, so the order of the two steps matters.
    NaN entries compare false and pass through unchanged.
    """
    two_pi = 2*np.pi
    lower, upper = {0: (-1, np.pi), 1: (-1, 2), 2: (None, 1)}[j]
    if lower is not None:
        phase = np.where(phase < lower, phase + two_pi, phase)
    return np.where(phase > upper, phase - two_pi, phase)


def plot_colormaps_diff_phase(cases, color_dict, save_name):
    """Box-plot the box-minus-slice displacement-phase difference in cycles.

    The reference data is read from 'comsol_raw/box_colormap_phase.csv'; for
    every entry of ``cases`` the file 'comsol_raw/slice_colormap_phase<case>.csv'
    is read and its 'solid.uPhase{Z,X,Y}' columns are compared with the
    matching reference columns. At each frequency both data sets are offset by
    their own first non-NaN Z-direction phase, wrapped, and the difference
    ``(box - slice) / (2*pi)`` is summarised as one box per
    (case, direction, frequency). The figure is saved, shown, and closed.

    Args:
        cases: sequence of file-name suffixes; each must be a key of
            ``color_dict``. A repeated suffix is simply read and drawn again.
        color_dict: mapping from case suffix to a matplotlib color.
        save_name: output path handed to ``plt.savefig``.

    Raises:
        IndexError: if a Z-direction phase column contains only NaN.
    """
    freqs = [4000, 4500, 5000, 5500, 6000, 6500, 7000]
    directions = ['Z', 'X', 'Y']

    # Per-frequency phase offsets; filled while processing the Z direction
    # (j == 0) and reused for the X and Y directions of the same case.
    box_shift = np.zeros(len(freqs))
    slice_shift = np.zeros(len(freqs))

    # skiprows=8 presumably skips the COMSOL export header -- confirm against the files.
    box_df = pd.read_csv('comsol_raw/box_colormap_phase.csv', skiprows=8, index_col=[0, 1])

    fig, axarr = plt.subplots(1, 3, sharex='col', sharey='row', figsize=[8, 3])
    for case in cases:
        print('case = '+case)

        slice_df = pd.read_csv('comsol_raw/slice_colormap_phase' + case + '.csv',
                               skiprows=8, index_col=[0, 1])
        color = color_dict[case]

        for j, direction in enumerate(directions):
            print(j, direction)
            # The Y direction is deliberately not plotted for the '_cont' case.
            if direction == 'Y' and case == '_cont':
                continue

            for i, freqi in enumerate(freqs):
                print(i, freqi)

                column = 'solid.uPhase' + direction + ' (rad) @ freq=' + str(freqi)
                box_z = np.asarray(box_df.loc[:, column])
                slice_z = np.asarray(slice_df.loc[:, column])

                if j == 0:
                    k, box_shift[i] = _first_non_nan(box_z)
                    print(k, box_shift[i])

                    k, slice_shift[i] = _first_non_nan(slice_z)
                    print(k, slice_shift[i])

                box_z = _wrap_phase(box_z - box_shift[i], j)
                slice_z = _wrap_phase(slice_z - slice_shift[i], j)

                # Radians -> cycles; drop points missing in either data set.
                z = (box_z - slice_z)/2/np.pi
                z = z[~np.isnan(z)]

                axarr[j].boxplot(
                    z, positions=[freqi/1000], widths=0.25, patch_artist=True,
                    showmeans=False, showfliers=False,
                    medianprops={"color": "white", "linewidth": 0.5},
                    boxprops={"facecolor": color, "edgecolor": "none",
                              "linewidth": 0.5},
                    whiskerprops={"color": color, "linewidth": 1.5},
                    capprops={"color": color, "linewidth": 1.5})

    for ax in axarr:
        # Horizontal zero-difference reference line across the visible range.
        ax.plot([3.5, 7.5], [0, 0], linestyle='-', color='dimgray', linewidth=1)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_xlabel('frequency (kHz)')
        ax.set_xlim(3.5, 7.5)
        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.xaxis.set_major_formatter('{x:.0f}')
        ax.xaxis.set_minor_locator(MultipleLocator(0.5))

    axarr[0].set_ylabel('difference re. full-length (cycle)')
    fig.tight_layout()
    plt.savefig(save_name, transparent=True)
    plt.show()
    plt.close()


def plot_cases(study):
    """Render the amplitude and phase comparison figures for one study.

    Selects the case list and legend labels for ``study`` and writes
    'manufigs/colormap_diff<study>.pdf' and
    'manufigs/colormap_diff_phase<study>.pdf'.

    Args:
        study: '_locs' or '_ks'.

    Raises:
        ValueError: if ``study`` is not one of the supported values.
    """
    color_dict = {'_st': 'black',
                  '_sv': 'tab:green',
                  '': 'tab:purple',
                  '_k_x_2': 'peachpuff',
                  '_k_x_3': 'lightsalmon',
                  '_k_div_3': 'lightblue',
                  '_cont': 'steelblue'}

    # Some cases appear more than once with an empty label: they are drawn
    # again later in the sequence without adding a visible legend entry.
    if study == '_locs':
        cases = ['_st', '_sv', '', '_st']
        plot_labels = ['ST stimulation', 'SV stimulation', 'SV & ST stml', '']
    elif study == '_ks':
        cases = ['_st', '_k_x_2', '_k_x_3', '_k_div_3', '_k_x_2', '_cont', '_k_div_3', '_st']
        plot_labels = ['slice k', 'slice 2*k', 'slice 3*k', 'slice k/3', '', 'continuity', '', '']
    else:
        # Previously an unknown study fell through and failed later with an
        # UnboundLocalError on `cases`; fail early with a clear message.
        raise ValueError(
            "unknown study {!r}; expected '_locs' or '_ks'".format(study))

    plot_colormaps_diff(cases, plot_labels, color_dict, 'manufigs/colormap_diff' + study + '.pdf', True)
    plot_colormaps_diff_phase(cases, color_dict, 'manufigs/colormap_diff_phase' + study + '.pdf')


# Render the figures for both studies, in the same order as before.
for study_suffix in ('_locs', '_ks'):
    plot_cases(study_suffix)